package main

import (
	"fmt"
	"os"
	"strconv"
	"strings"

	"go.flipt.io/flipt/rpc/flipt/options"
	"google.golang.org/genproto/googleapis/api/annotations"
	"google.golang.org/genproto/googleapis/api/serviceconfig"
	"google.golang.org/protobuf/compiler/protogen"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/types/descriptorpb"
	"sigs.k8s.io/yaml"
)

// httpbodyImport is the Go import path of the genproto httpbody package.
// generateHTTPMethod copies the raw response body and Content-Type into
// outputs whose type comes from this package instead of JSON-decoding them.
const httpbodyImport = "google.golang.org/genproto/googleapis/api/httpbody"

// generateHTTP emits the net/http based transport for the Flipt SDK into the
// "http" sub-package of importPath.
//
// HTTP rules are collected from the gRPC API config file at grpcAPIConfig
// (YAML decoded into a google.api.Service) and from google.api.http
// annotations on the methods being generated. It writes http/http.sdk.gen.go,
// holding the shared Transport type, plus one http/<go package>.sdk.gen.go per
// generated proto file with the HTTP clients for that file's services.
//
// Errors reading or decoding grpcAPIConfig cause a panic.
func generateHTTP(gen *protogen.Plugin, grpcAPIConfig string, importPath protogen.GoImportPath) {
	rules := loadHTTPRuleMappings(gen, grpcAPIConfig)

	generateHTTPTransportFile(gen, importPath)

	for _, file := range gen.Files {
		if !file.Generate {
			continue
		}
		generateHTTPClientFile(gen, file, rules, importPath)
	}
}

// loadHTTPRuleMappings builds the selector -> rule table from the API config
// file and the google.api.http method annotations. Annotation rules are
// appended after the config rules, so they win when both name the same method.
func loadHTTPRuleMappings(gen *protogen.Plugin, grpcAPIConfig string) mappings {
	data, err := os.ReadFile(grpcAPIConfig)
	if err != nil {
		panic(err)
	}

	jsonData, err := yaml.YAMLToJSON(data)
	if err != nil {
		panic(err)
	}

	var config serviceconfig.Service
	if err := (protojson.UnmarshalOptions{DiscardUnknown: true}).Unmarshal(jsonData, &config); err != nil {
		panic(err)
	}

	// A config without an `http:` section leaves Http nil; allocate it so the
	// annotation rules below can still be appended without a nil dereference.
	if config.Http == nil {
		config.Http = &annotations.Http{}
	}

	for _, f := range gen.Files {
		if !f.Generate {
			continue
		}
		for _, s := range f.Services {
			for _, method := range s.Methods {
				if shouldIgnoreMethod(method) {
					continue
				}

				if r := proto.GetExtension(method.Desc.Options(), annotations.E_Http).(*annotations.HttpRule); r != nil {
					r.Selector = string(method.Desc.FullName())
					config.Http.Rules = append(config.Http.Rules, r)
				}
			}
		}
	}

	m := make(mappings, len(config.Http.Rules))
	for _, r := range config.Http.Rules {
		var rl rule

		switch p := r.Pattern.(type) {
		case *annotations.HttpRule_Get:
			rl.method, rl.path = "MethodGet", p.Get
		case *annotations.HttpRule_Post:
			rl.method, rl.path = "MethodPost", p.Post
		case *annotations.HttpRule_Put:
			rl.method, rl.path = "MethodPut", p.Put
		case *annotations.HttpRule_Patch:
			rl.method, rl.path = "MethodPatch", p.Patch
		case *annotations.HttpRule_Delete:
			rl.method, rl.path = "MethodDelete", p.Delete
		default:
			// Skip the rule: recording it would leave an empty method and
			// path, which generates uncompilable code for the method.
			fmt.Fprintf(os.Stderr, "unsupported pattern: %T\n", r.Pattern)
			continue
		}

		rl.body = r.Body == "*"
		m[r.Selector] = rl
	}

	return m
}

// generateHTTPTransportFile writes http/http.sdk.gen.go: the Transport type,
// its functional options, and the helpers shared by every generated client.
func generateHTTPTransportFile(gen *protogen.Plugin, importPath protogen.GoImportPath) {
	g := gen.NewGeneratedFile("http/http.sdk.gen.go", importPath+"/http")
	g.P("// Code generated by protoc-gen-go-flipt-sdk. DO NOT EDIT.")
	g.P()
	g.P("package http")
	g.P()

	var (
		sdk     = importPackage(g, importPath)
		netHTTP = importPackage(g, "net/http")

		metadata   = importPackage(g, "google.golang.org/grpc/metadata")
		pbjson     = importPackage(g, "google.golang.org/protobuf/encoding/protojson")
		pbstatus   = importPackage(g, "google.golang.org/genproto/googleapis/rpc/status")
		grpcstatus = importPackage(g, "google.golang.org/grpc/status")
	)

	g.P("var _ ", sdk("Transport"), " = Transport{}")

	g.P("type Transport struct {")
	g.P("client *", netHTTP("Client"))
	g.P("addr string")
	g.P("}\n")

	g.P("type Option func(*Transport)")
	g.P("func WithHTTPClient(client *", netHTTP("Client"), ") Option {")
	g.P("return func(t *Transport) { t.client = client }")
	g.P("}\n")

	// The generated NewTransport wraps the client's RoundTripper so that an
	// "authorization" value in the outgoing gRPC metadata is forwarded as the
	// HTTP Authorization header. The client is shallow-copied first so a
	// client passed via WithHTTPClient is not mutated. A nil Transport falls
	// back to http.DefaultTransport instead of panicking on the first request.
	g.P("func NewTransport(addr string, opts ...Option) Transport {")
	g.P("t := Transport{")
	g.P("client: &", netHTTP("Client"), "{ Transport: ", netHTTP("DefaultTransport"), " },")
	g.P("addr: addr,")
	g.P("}")
	g.P("for _, opt := range opts { opt(&t) }")
	g.P("client := *t.client")
	g.P("transport := client.Transport")
	g.P("if transport == nil { transport = ", netHTTP("DefaultTransport"), " }")
	g.P("client.Transport = roundTripFunc(func(r *", netHTTP("Request"), ") (*", netHTTP("Response"), ", error) {")
	g.P("md, ok := ", metadata("FromOutgoingContext"), "(r.Context())")
	g.P("if ok {")
	g.P(`if auth := md.Get("authorization"); len(auth) > 0 {`)
	g.P(`r.Header.Set("Authorization", auth[0])`)
	g.P("}")
	g.P("}\n")
	g.P("return transport.RoundTrip(r)")
	g.P("})\n")
	g.P("t.client = &client")
	g.P("return t")
	g.P("}\n")

	// The decoded status is held in "st" rather than "status": two imported
	// packages share the base name "status", and a local of that name could
	// shadow one of them.
	g.P("func checkResponse(resp *", netHTTP("Response"), ", v []byte) error {")
	g.P("if resp.StatusCode != ", netHTTP("StatusOK"), "{")
	g.P("var st ", pbstatus("Status"))
	g.P("if err := ", pbjson("Unmarshal"), "(v, &st); err != nil { return err }")
	g.P("return ", grpcstatus("ErrorProto"), "(&st)")
	g.P("}\n")
	g.P("return nil")
	g.P("}\n")

	g.P("type roundTripFunc func(r *", netHTTP("Request"), ") (*", netHTTP("Response"), ", error)")
	g.P("func (f roundTripFunc) RoundTrip(r *", netHTTP("Request"), ") (*", netHTTP("Response"), ", error) {")
	g.P("return f(r)")
	g.P("}\n")
}

// generateHTTPClientFile writes http/<go package>.sdk.gen.go for file. It
// contains the client type(s) implementing the file's services over HTTP and
// the Transport accessor method that returns them.
func generateHTTPClientFile(gen *protogen.Plugin, file *protogen.File, rules mappings, importPath protogen.GoImportPath) {
	filename := "http/" + string(file.GoPackageName) + ".sdk.gen.go"
	g := gen.NewGeneratedFile(filename, importPath+"/http")
	g.P("// Code generated by protoc-gen-go-flipt-sdk. DO NOT EDIT.")
	g.P()
	g.P("package http")
	g.P()

	var (
		// accessor names the Transport method returning this package's
		// client: the title-cased Go package name followed by "Client".
		accessor = strings.Title(string(file.GoPackageName)) + "Client"

		// imported packages
		sdk     = importPackage(g, importPath)
		netHTTP = importPackage(g, "net/http")
	)

	if len(file.Services) == 1 {
		srv := file.Services[0]
		returnType := srv.GoName + "Client"

		g.P("type ", returnType, " struct {")
		g.P("client *", netHTTP("Client"))
		g.P("addr string")
		g.P("}\n")

		for _, method := range srv.Methods {
			if shouldIgnoreMethod(method) {
				continue
			}
			generateHTTPMethod(g, rules, method, returnType)
		}

		g.P("func (t Transport) ", accessor, "() ", relativeImport(g, file, returnType), "{")
		g.P("return &", returnType, "{ client: t.client, addr: t.addr }")
		g.P("}\n")
		return
	}

	// Packages containing more than one service definition are bundled into a
	// single unexported type which implements the combined client interface
	// the SDK generator produces.
	groupType := unexport(accessor)
	g.P("type ", groupType, " struct {")
	g.P("client *", netHTTP("Client"))
	g.P("addr string")
	g.P("}\n")

	for _, srv := range file.Services {
		var (
			returnType           = srv.GoName + "Client"
			unexportedReturnType = unexport(returnType)
		)

		g.P("func (t ", groupType, ") ", returnType, "() ", relativeImport(g, file, returnType), " {")
		g.P("return &", unexportedReturnType, "{ client: t.client, addr: t.addr }")
		g.P("}\n")

		g.P("type ", unexportedReturnType, " struct {")
		g.P("client *", netHTTP("Client"))
		g.P("addr string")
		g.P("}\n")

		for _, method := range srv.Methods {
			if shouldIgnoreMethod(method) {
				continue
			}
			generateHTTPMethod(g, rules, method, unexportedReturnType)
		}
	}

	g.P("func (t Transport) ", accessor, "() ", sdk(accessor), "{")
	g.P("return ", groupType, "{client: t.client, addr: t.addr}")
	g.P("}\n")
}

// generateHTTPMethod writes the HTTP implementation of method as a method on
// the generated client type typ (pointer receiver). Methods with no entry in
// m are skipped.
//
// The generated method builds its URL from the rule's path template (see
// renderPath). When the rule's body is "*" the whole request is sent as JSON.
// Otherwise every request field not bound by the path is sent as a query
// parameter. The response is read fully and passed to checkResponse. It is
// then either copied verbatim (outputs from the httpbody package) or decoded
// as JSON into the output message, discarding unknown fields.
//
// Every package referenced by generated code goes through an importPackage
// qualifier so protogen records the import and picks its local name.
func generateHTTPMethod(g *protogen.GeneratedFile, m mappings, method *protogen.Method, typ string) {
	rule, ok := m[string(method.Desc.FullName())]
	if !ok {
		return
	}

	var (
		context = importPackage(g, "context")
		io      = importPackage(g, "io")
		bytes   = importPackage(g, "bytes")
		pkgfmt  = importPackage(g, "fmt")

		netURL  = importPackage(g, "net/url")
		netHTTP = importPackage(g, "net/http")

		grpc   = importPackage(g, "google.golang.org/grpc")
		pbjson = importPackage(g, "google.golang.org/protobuf/encoding/protojson")
	)

	g.P("func (x *", typ, ") ", method.GoName, "(ctx ", context("Context"), ", v *", method.Input.GoIdent, ", _ ..."+grpc("CallOption")+") (*", method.Output.GoIdent, ", error) {")

	g.P("var body ", io("Reader"))
	path, inPath := renderPath(pkgfmt, rule, method.Input)
	if rule.body {
		g.P("var values ", netURL("Values"))
		g.P("reqData, err := ", pbjson("Marshal"), "(v)")
		g.P("if err != nil { return nil, err }")
		g.P("body = ", bytes("NewReader"), "(reqData)")
	} else {
		// NOTE(review): flipt_client path_defaults are only applied on this
		// no-body branch; confirm that body rules never need them.
		var pathDefaults map[string]string
		if opts, ok := method.Desc.Options().(*descriptorpb.MethodOptions); ok {
			fc, ok := proto.GetExtension(opts, options.E_FliptClient).(*options.FliptClient)
			if ok && fc != nil {
				pathDefaults = fc.PathDefaults
			}
		}

		var (
			setValues  []string
			hasMessage bool
		)

		for _, field := range method.Input.Fields {
			val := "v." + field.GoName

			if _, ok := inPath[field]; ok {
				// Default empty string path variables. The generated code
				// assigns the default into v, modifying the caller's request.
				if def, ok := pathDefaults[string(field.Desc.Name())]; ok && field.Desc.Kind() == protoreflect.StringKind {
					g.P("if ", val, " == \"\" {")
					g.P(val, " = ", strconv.Quote(def))
					g.P("}")
				}
				continue
			}

			key := field.Desc.JSONName()
			switch field.Desc.Kind() {
			case protoreflect.StringKind:
				setValues = append(setValues, fmt.Sprintf(`values.Set("%s", %s)`, key, val))
			case protoreflect.MessageKind:
				hasMessage = true
				// Unset (nil) messages are skipped rather than marshalled,
				// so no query parameter is sent for them.
				setValues = append(setValues, fmt.Sprintf(
					"if %s != nil {\nfield, err = %s(%s)\nif err != nil { return nil, err }\nvalues.Set(\"%s\", unquote(field))\n}",
					val, pbjson("Marshal"), val, key))
			default:
				setValues = append(setValues, fmt.Sprintf(`values.Set("%s", %s("%%v", %s))`, key, pkgfmt("Sprintf"), val))
			}
		}

		// only allocate if we have any values to set on the query
		if len(setValues) == 0 {
			g.P("var values ", netURL("Values"))
		} else {
			if hasMessage {
				// unquote strips the surrounding quotes when the marshalled
				// JSON is a string literal, and otherwise returns it unchanged.
				g.P("var field []byte")
				g.P("var err error")
				g.P("var unquote = func(v []byte) string {")
				g.P("s, err := ", importPackage(g, "strconv")("Unquote"), "(string(v))")
				g.P("if err == nil { return s }")
				g.P("return string(v)")
				g.P("}")
			}
			g.P("values := ", netURL("Values"), "{}")
			for _, val := range setValues {
				g.P(val)
			}
		}
	}

	g.P("req, err := ", netHTTP("NewRequestWithContext"), "(ctx, ", netHTTP(rule.method), ", x.addr+", path, ", body)")
	g.P("if err != nil { return nil, err }")

	g.P("req.URL.RawQuery = values.Encode()")

	g.P("resp, err := x.client.Do(req)")
	g.P("if err != nil { return nil, err }")
	g.P("defer resp.Body.Close()")

	g.P("var output ", method.Output.GoIdent)
	g.P("respData, err := ", io("ReadAll"), "(resp.Body)")
	g.P("if err != nil { return nil, err }")

	g.P("if err := checkResponse(resp, respData); err != nil {return nil, err}")

	// httpbody outputs carry the entire response body as-is in their Data
	// field, together with the response Content-Type.
	if method.Output.GoIdent.GoImportPath == httpbodyImport {
		g.P(`output.ContentType = resp.Header.Get("Content-Type")`)
		g.P(`output.Data = respData`)
		g.P("return &output, nil")
		g.P("}\n")
		return
	}

	g.P("if err := (", pbjson("UnmarshalOptions"), "{DiscardUnknown: true}).Unmarshal", "(respData, &output); err != nil { return nil, err }")
	g.P("return &output, nil")
	g.P("}\n")
}

// rule is the HTTP binding used to generate the client implementation of a
// single RPC method.
type rule struct {
	// method is the name of the net/http method constant to use,
	// e.g. "MethodGet" or "MethodPost".
	method string
	// path is the URL path template from the rule's pattern; segments of
	// the form {field} or {field=pattern} are bound to request fields.
	path   string
	// body reports whether the rule's body is "*", in which case the whole
	// request is sent as the JSON body; otherwise request fields not bound
	// by the path are sent as query parameters.
	body   bool
}

// mappings maps a method's fully-qualified protobuf name (the HTTP rule
// selector) to the rule used to generate its HTTP client implementation.
type mappings map[string]rule

// renderPath turns rule.path, a google.api.http style URL path template, into
// Go source for an expression that evaluates to the request path.
//
// Each "{name}" or "{name=pattern}" segment whose name matches a field of msg
// becomes a "%v" verb bound to "v.<GoName>", and the expression is a call to
// pkgfmt("Sprintf"). With no bound fields the expression is a plain string
// literal. Segments that name no field of msg are copied verbatim. Empty
// segments are dropped, a leading "/" is always emitted, and a trailing "/"
// in the template is kept.
//
// The returned set holds the fields bound by the path so callers can leave
// them out of the query string.
func renderPath(pkgfmt func(string) string, rule rule, msg *protogen.Message) (string, map[*protogen.Field]struct{}) {
	var (
		args   []string
		inPath = map[*protogen.Field]struct{}{}

		// format is the path as a Sprintf format string (literal '%'
		// escaped); plain is the same path unescaped, used when no
		// variables are bound and the result is a plain literal.
		format, plain strings.Builder
	)

	format.WriteString(`"/`)
	plain.WriteString(`"/`)

	parts := strings.Split(rule.path, "/")
	for i, p := range parts {
		if p == "" {
			continue
		}

		var bound *protogen.Field
		if strings.HasPrefix(p, "{") && strings.HasSuffix(p, "}") {
			name, _, _ := strings.Cut(p[1:len(p)-1], "=")
			for _, field := range msg.Fields {
				if string(field.Desc.Name()) == name {
					bound = field
					break
				}
			}
		}

		plain.WriteString(p)
		if bound != nil {
			format.WriteString("%v")
			args = append(args, "v."+bound.GoName)
			inPath[bound] = struct{}{}
		} else {
			// escape '%' so literal text is not read as a Sprintf verb
			format.WriteString(strings.ReplaceAll(p, "%", "%%"))
		}

		if i < len(parts)-1 {
			format.WriteByte('/')
			plain.WriteByte('/')
		}
	}

	if len(args) == 0 {
		return plain.String() + `"`, inPath
	}

	return fmt.Sprintf("%s(%s, %s)", pkgfmt("Sprintf"), format.String()+`"`, strings.Join(args, ",")), inPath
}
